"""Convert a GeoArrow array to interleaved representation."""

import numpy as np
from arro3.core import (
    Array,
    ChunkedArray,
    DataType,
    Field,
    Table,
    fixed_size_list_array,
    struct_field,
)

from lonboard._constants import EXTENSION_NAME
from lonboard._geoarrow.extension_types import CoordinateDimension, coord_storage_type
from lonboard._geoarrow.ops.reproject import (
    _map_coords_nest_0,
    _map_coords_nest_1,
    _map_coords_nest_2,
    _map_coords_nest_3,
)
from lonboard._utils import get_geometry_column_index


def make_geometry_interleaved(
    table: Table,
) -> Table:
    """Convert geometry column in table from struct to interleaved coordinate layout.

    The table is returned untouched when it has no geometry column, or when its
    geometry column carries the GeoArrow box extension type (box is only
    defined with a struct layout and is handled elsewhere).

    Args:
        table: Table that may contain a GeoArrow geometry column.

    Returns:
        A table whose geometry column (if any) has been rewritten by
        ``convert_struct_column_to_interleaved``.
    """
    idx = get_geometry_column_index(table.schema)
    if idx is None:
        return table

    field = table.schema.field(idx)
    is_box = field.metadata.get(b"ARROW:extension:name") == EXTENSION_NAME.BOX
    if is_box:
        # Box has no interleaved form; converting it to polygons, when needed,
        # happens separately.
        return table

    interleaved_field, interleaved_column = convert_struct_column_to_interleaved(
        field=field,
        column=table.column(idx),
    )
    return table.set_column(idx, interleaved_field, interleaved_column)


def convert_struct_column_to_interleaved(
    *,
    field: Field,
    column: ChunkedArray,
) -> tuple[Field, ChunkedArray]:
    """Convert a GeoArrow column from struct to interleaved coordinate layout.

    Args:
        field: Field describing ``column``. Its metadata must contain
            ``ARROW:extension:name``; a ``KeyError`` is raised otherwise.
        column: The geometry column to convert.

    Returns:
        A ``(field, column)`` pair: the converted column and the input field
        retyped to match that column's new data type.
    """
    ext_name = field.metadata[b"ARROW:extension:name"]
    converted = _convert_column(column, extension_type_name=ext_name)
    retyped_field = field.with_type(converted.type)
    return retyped_field, converted


def _convert_column(
    column: ChunkedArray,
    *,
    extension_type_name: bytes,
) -> ChunkedArray:
    """Interleave the coordinates of every chunk of a GeoArrow column.

    The extension type name selects which ``_transpose_chunk_nest_N`` helper
    is applied to each chunk: point -> 0; linestring/multipoint -> 1;
    polygon/multilinestring -> 2; multipolygon -> 3.

    Args:
        column: The geometry column to convert.
        extension_type_name: The column's ``ARROW:extension:name`` value.

    Returns:
        A new chunked array of converted chunks. Its field is taken from the
        first converted chunk, with the original column's metadata attached.

    Raises:
        ValueError: If ``extension_type_name`` is not one of the supported
            geometry types listed above.
    """
    # (accepted extension names, per-chunk converter) pairs.
    converters = (
        ((EXTENSION_NAME.POINT,), _transpose_chunk_nest_0),
        (
            (EXTENSION_NAME.LINESTRING, EXTENSION_NAME.MULTIPOINT),
            _transpose_chunk_nest_1,
        ),
        (
            (EXTENSION_NAME.POLYGON, EXTENSION_NAME.MULTILINESTRING),
            _transpose_chunk_nest_2,
        ),
        ((EXTENSION_NAME.MULTIPOLYGON,), _transpose_chunk_nest_3),
    )
    for names, transpose_chunk in converters:
        if extension_type_name in names:
            break
    else:
        raise ValueError(f"Unexpected extension type name {extension_type_name!r}")

    converted = [transpose_chunk(chunk) for chunk in column.chunks]
    # NOTE(review): a column with zero chunks makes ``converted[0]`` raise
    # IndexError; confirm callers never pass an empty ChunkedArray.
    out_type = converted[0].field.with_metadata(column.field.metadata)
    return ChunkedArray(converted, type=out_type)


def _transpose_coords(arr: Array) -> Array:
    """Return a coordinate array in interleaved (fixed-size-list) layout.

    A fixed-size-list input is returned as-is. A struct input with 2 (XY) or
    3 (XYZ) fields is read field by field, stacked so that each row holds one
    coordinate tuple, flattened row-major, and rebuilt as a fixed-size list
    whose width equals the number of fields.

    Raises:
        ValueError: If a non-fixed-size-list input does not have 2 or 3 fields.
    """
    # Already interleaved: nothing to do.
    if DataType.is_fixed_size_list(arr.type):
        return arr

    n_dims = arr.type.num_fields
    if n_dims == 2:
        dims = CoordinateDimension.XY
    elif n_dims == 3:
        dims = CoordinateDimension.XYZ
    else:
        raise ValueError(
            f"Expected struct with 2 or 3 fields, got {arr.type.num_fields}"
        )

    # One numpy array per struct field, in field order (x, y[, z]).
    columns = [struct_field(arr, [i]).to_numpy() for i in range(n_dims)]
    # Row-major flattening of the (n, n_dims) stack yields x0, y0, [z0,] x1, ...
    flat = np.column_stack(columns).ravel("C")
    return fixed_size_list_array(
        flat,
        n_dims,
        type=coord_storage_type(interleaved=True, dims=dims),
    )


def _transpose_chunk_nest_0(arr: Array) -> Array:
    """Interleave the coordinates of one chunk of a point column.

    Delegates to ``_map_coords_nest_0`` with ``_transpose_coords`` as the
    coordinate transform. NOTE(review): presumably ``_map_coords_nest_0``
    applies the transform directly to ``arr`` (no list nesting); confirm in
    ``lonboard._geoarrow.ops.reproject``.
    """
    return _map_coords_nest_0(arr, _transpose_coords)


def _transpose_chunk_nest_1(arr: Array) -> Array:
    """Interleave the coordinates of one linestring or multipoint chunk.

    Delegates to ``_map_coords_nest_1`` with ``_transpose_coords`` as the
    coordinate transform. NOTE(review): presumably the transform is applied to
    the coordinate array one list level below ``arr``; confirm in
    ``lonboard._geoarrow.ops.reproject``.
    """
    return _map_coords_nest_1(arr, _transpose_coords)


def _transpose_chunk_nest_2(arr: Array) -> Array:
    """Interleave the coordinates of one polygon or multilinestring chunk.

    Delegates to ``_map_coords_nest_2`` with ``_transpose_coords`` as the
    coordinate transform. NOTE(review): presumably the transform is applied to
    the coordinate array two list levels below ``arr``; confirm in
    ``lonboard._geoarrow.ops.reproject``.
    """
    return _map_coords_nest_2(arr, _transpose_coords)


def _transpose_chunk_nest_3(arr: Array) -> Array:
    """Interleave the coordinates of one multipolygon chunk.

    Delegates to ``_map_coords_nest_3`` with ``_transpose_coords`` as the
    coordinate transform. NOTE(review): presumably the transform is applied to
    the coordinate array three list levels below ``arr``; confirm in
    ``lonboard._geoarrow.ops.reproject``.
    """
    return _map_coords_nest_3(arr, _transpose_coords)
